{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "78365427-e6e1-4ba3-ad42-69d9c8620a00",
   "metadata": {},
   "source": [
    "# DistilBERT Classifier as Feature Extractor Using Embetter"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2c384b90-6074-4a96-b84d-3f1f29e10b7f",
   "metadata": {},
   "source": [
    "In this feature-based approach, we use a pretrained transformer as a fixed feature extractor: its embeddings serve as the input features for a logistic regression classifier trained in scikit-learn:\n",
    "\n",
    "![](figures/feature-extractor.jpeg)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "9eea2bca-b66f-4751-8af9-0c1e7631563e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pip install transformers datasets sentence-transformers embetter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "33541bc3-0e07-4808-b6b6-4d7578814ab6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# conda install scikit-learn --yes"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9933a986-37ae-49a1-9da4-2dd0f54b2d61",
   "metadata": {},
   "source": [
    "In addition, we will be using [embetter](https://github.com/koaning/embetter), a library that exposes pretrained embedding models as scikit-learn-compatible transformers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a6ab918a-b0a2-4e39-a316-fa14b1fd3207",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch       : 1.12.1\n",
      "transformers: 4.23.1\n",
      "datasets    : 2.6.1\n",
      "sklearn     : 0.0\n",
      "\n",
      "conda environment: dl-fundamentals\n",
      "\n"
     ]
    }
   ],
   "source": [
    "%load_ext watermark\n",
    "%watermark --conda -p torch,transformers,datasets,sklearn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "1e58e25b-8ef2-4f03-87b2-9fea4728aef3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ecf02113-10e8-41a8-b7cb-029d24cb0593",
   "metadata": {
    "tags": []
   },
   "source": [
    "# 1 Loading the Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1f871d7d-d973-4f34-a014-5961e173280a",
   "metadata": {},
   "source": [
    "The IMDB movie review dataset consists of 50k movie reviews with sentiment labels (0: negative, 1: positive)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "31944168-80db-47ed-95d1-c40459e16343",
   "metadata": {},
   "source": [
    "## 1a) Load from `datasets` Hub"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "53160e55-e19f-40b9-bfe3-12e15bc2835b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from datasets import list_datasets, load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "4db26996-6616-4872-a894-448c9669c1e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "# list_datasets()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "eef7f55e-f6a3-40e7-85ff-270ef11aca70",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Found cached dataset imdb (/home/raschka/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8992eba4dc384426a34944a7f102123f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/3 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DatasetDict({\n",
      "    train: Dataset({\n",
      "        features: ['text', 'label'],\n",
      "        num_rows: 25000\n",
      "    })\n",
      "    test: Dataset({\n",
      "        features: ['text', 'label'],\n",
      "        num_rows: 25000\n",
      "    })\n",
      "    unsupervised: Dataset({\n",
      "        features: ['text', 'label'],\n",
      "        num_rows: 50000\n",
      "    })\n",
      "})\n"
     ]
    }
   ],
   "source": [
    "imdb_data = load_dataset(\"imdb\")\n",
    "print(imdb_data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2ed3fec2-a35a-4dc0-b863-e0c5c1799654",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text': \"This film is terrible. You don't really need to read this review further. If you are planning on watching it, suffice to say - don't (unless you are studying how not to make a good movie).<br /><br />The acting is horrendous... serious amateur hour. Throughout the movie I thought that it was interesting that they found someone who speaks and looks like Michael Madsen, only to find out that it is actually him! A new low even for him!!<br /><br />The plot is terrible. People who claim that it is original or good have probably never seen a decent movie before. Even by the standard of Hollywood action flicks, this is a terrible movie.<br /><br />Don't watch it!!! Go for a jog instead - at least you won't feel like killing yourself.\",\n",
       " 'label': 0}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "imdb_data[\"train\"][99]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f7439730-f7fc-4b08-b7e6-b32ff72c9723",
   "metadata": {
    "tags": []
   },
   "source": [
    "## 1b) Load from local directory"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca5f590b-7e8b-492a-b937-33f3ea49036c",
   "metadata": {},
   "source": [
    "The IMDB movie review dataset can be downloaded from http://ai.stanford.edu/~amaas/data/sentiment/. After downloading the dataset, decompress the files.\n",
    "\n",
    "A) If you are working with Linux or macOS, open a new terminal window, `cd` into the download directory, and execute\n",
    "\n",
    "    tar -zxf aclImdb_v1.tar.gz\n",
    "\n",
    "B) If you are working with Windows, download an archiver such as 7-Zip to extract the files from the download archive."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "07f3cf11-550d-4544-b9e1-9db656a665a3",
   "metadata": {},
   "source": [
    "C) Alternatively, use the following code to download and extract the dataset via Python:"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "980e46a7-cfa0-420e-820d-c1425de5ca55",
   "metadata": {},
   "source": [
    "**Download the movie reviews**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "1c883c5d-27ee-4eac-bbd5-c31c70e0aaea",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "import tarfile\n",
    "import time\n",
    "import urllib.request\n",
    "\n",
    "source = \"http://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n",
    "target = \"aclImdb_v1.tar.gz\"\n",
    "\n",
    "if os.path.exists(target):\n",
    "    os.remove(target)\n",
    "\n",
    "\n",
    "def reporthook(count, block_size, total_size):\n",
    "    global start_time\n",
    "    if count == 0:\n",
    "        start_time = time.time()\n",
    "        return\n",
    "    duration = time.time() - start_time\n",
    "    progress_size = int(count * block_size)\n",
    "    speed = progress_size / (1024.0**2 * duration)\n",
    "    percent = count * block_size * 100.0 / total_size\n",
    "\n",
    "    sys.stdout.write(\n",
    "        f\"\\r{int(percent)}% | {progress_size / (1024.**2):.2f} MB \"\n",
    "        f\"| {speed:.2f} MB/s | {duration:.2f} sec elapsed\"\n",
    "    )\n",
    "    sys.stdout.flush()\n",
    "\n",
    "\n",
    "if not os.path.isdir(\"aclImdb\") and not os.path.isfile(\"aclImdb_v1.tar.gz\"):\n",
    "    urllib.request.urlretrieve(source, target, reporthook)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "be9ad2f8-7cb2-4c4f-937d-7d3f4315fbae",
   "metadata": {},
   "outputs": [],
   "source": [
    "if not os.path.isdir(\"aclImdb\"):\n",
    "\n",
    "    with tarfile.open(target, \"r:gz\") as tar:\n",
    "        tar.extractall()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "75b2c486-1a19-4ddb-b073-6958050e525b",
   "metadata": {},
   "source": [
    "**Convert them to a pandas DataFrame and save them as CSV**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "bd725b35-2044-4f0b-9516-046f8690dc20",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████████████████████| 50000/50000 [00:55<00:00, 893.83it/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from packaging import version\n",
    "from tqdm import tqdm\n",
    "\n",
    "# change the `basepath` to the directory of the\n",
    "# unzipped movie dataset\n",
    "\n",
    "basepath = \"aclImdb\"\n",
    "\n",
    "labels = {\"pos\": 1, \"neg\": 0}\n",
    "\n",
    "df = pd.DataFrame()\n",
    "\n",
    "with tqdm(total=50000) as pbar:\n",
    "    for s in (\"test\", \"train\"):\n",
    "        for l in (\"pos\", \"neg\"):\n",
    "            path = os.path.join(basepath, s, l)\n",
    "            for file in sorted(os.listdir(path)):\n",
    "                with open(os.path.join(path, file), \"r\", encoding=\"utf-8\") as infile:\n",
    "                    txt = infile.read()\n",
    "\n",
    "                if version.parse(pd.__version__) >= version.parse(\"1.3.2\"):\n",
    "                    x = pd.DataFrame(\n",
    "                        [[txt, labels[l]]], columns=[\"review\", \"sentiment\"]\n",
    "                    )\n",
    "                    df = pd.concat([df, x], ignore_index=False)\n",
    "\n",
    "                else:\n",
    "                    df = df.append([[txt, labels[l]]], ignore_index=True)\n",
    "                pbar.update()\n",
    "df.columns = [\"text\", \"label\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "f367f045-0da2-495c-8a24-115863925a15",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "np.random.seed(0)\n",
    "df = df.reindex(np.random.permutation(df.index))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e163e017-9674-4378-8454-69b8912a09f4",
   "metadata": {},
   "source": [
    "**Basic dataset analysis and sanity checks**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "2dd205ce-be31-410c-9d6b-d39895a9c634",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Class distribution:\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "array([25000, 25000])"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(\"Class distribution:\")\n",
    "np.bincount(df[\"label\"].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "93ea4c92-1e3c-497c-ae0e-e114437fd80c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(4, 173.0, 2470)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text_len = df[\"text\"].apply(lambda x: len(x.split()))\n",
    "text_len.min(), text_len.median(), text_len.max() "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9a95465c-35d6-4638-a3c1-1d8dfb3ff60b",
   "metadata": {},
   "source": [
    "**Split data into training, validation, and test sets**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "cd5f5326-8062-407f-8679-d701d9e2c169",
   "metadata": {},
   "outputs": [],
   "source": [
    "# drop=True discards the old index instead of keeping it as an extra \"index\" column\n",
    "df_shuffled = df.sample(frac=1, random_state=1).reset_index(drop=True)\n",
    "\n",
    "df_train = df_shuffled.iloc[:35_000]\n",
    "df_val = df_shuffled.iloc[35_000:40_000]\n",
    "df_test = df_shuffled.iloc[40_000:]\n",
    "\n",
    "df_train.to_csv(\"train.csv\", index=False, encoding=\"utf-8\")\n",
    "df_val.to_csv(\"validation.csv\", index=False, encoding=\"utf-8\")\n",
    "df_test.to_csv(\"test.csv\", index=False, encoding=\"utf-8\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "d4cc8710-e962-4513-ad70-01ddbe535e0a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>text</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>When we started watching this series on cable,...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Steve Biko was a black activist who tried to r...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>My short comment for this flick is go pick it ...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>As a serious horror fan, I get that certain ma...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Robert Cummings, Laraine Day and Jean Muir sta...</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                                text  label\n",
       "0  When we started watching this series on cable,...      1\n",
       "1  Steve Biko was a black activist who tried to r...      1\n",
       "2  My short comment for this flick is go pick it ...      1\n",
       "3  As a serious horror fan, I get that certain ma...      0\n",
       "4  Robert Cummings, Laraine Day and Jean Muir sta...      1"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df_train.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9185130-6530-44aa-8579-1505227e1be3",
   "metadata": {},
   "source": [
    "# 2 Train Model on Embeddings (Extracted Features)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "0c7819bb-6778-4c47-a061-e99c9542385c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from sklearn.pipeline import make_pipeline\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "\n",
    "from embetter.text import SentenceEncoder\n",
    "\n",
    "# SentenceEncoder maps each review to a fixed-size embedding vector using a\n",
    "# pretrained sentence-transformers model; LogisticRegression is trained on top\n",
    "classifier = make_pipeline(\n",
    "    SentenceEncoder(\"distiluse-base-multilingual-cased-v2\"),\n",
    "    LogisticRegression(),\n",
    ")\n",
    "\n",
    "classifier.fit(df_train[\"text\"].values, df_train[\"label\"].values);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d569d831-ea20-4cb1-a454-34e88719a6af",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy on the validation set\n",
    "classifier.score(df_val[\"text\"].values, df_val[\"label\"].values)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "09a31990-004c-423f-ab3e-79720e491c4c",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.8032"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Mean accuracy on the held-out test set\n",
    "classifier.score(df_test[\"text\"].values, df_test[\"label\"].values)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
